{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "937c4760-7ea9-43ce-8e53-77ab5c92f3b7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "* running on cpu\n"
     ]
    }
   ],
   "source": [
    "########################################################################################################\n",
    "# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM\n",
    "########################################################################################################\n",
    "\n",
    "import numpy as np\n",
    "import types\n",
    "import copy\n",
    "import torch\n",
    "from torch.nn import functional as F\n",
    "\n",
    "from model import RWKV_RNN\n",
    "\n",
    "np.set_printoptions(precision=4, suppress=True, linewidth=200)"
   ]
  },
  {
   "cell_type": "raw",
   "id": "e4e4f51e-b01b-45ac-8418-4ba32460e305",
   "metadata": {},
   "source": [
    "模型下载地址（本脚本请用如下169m参数的，不要用这个链接）：https://hf-mirror.com/BlinkDL/rwkv-3-pile-430m/resolve/main/RWKV-3-Pile-430M-20220817-10602.pth?download=true\n",
    "模型下载地址（用这个）：https://hf-mirror.com/BlinkDL/rwkv-3-pile-169m/resolve/main/RWKV-3-Pile-20220720-10704.pth?download=true"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "cd2655e8-ed8f-4d2d-9852-5101cb0bf70e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---> Edit src/model.py to set MODEL_NAME and CPU / CUDA mode <---\n",
    "\n",
    "TEMPERATURE = 1.0\n",
    "TOP_P = 0.7\n",
    "\n",
    "DEBUG_DEBUG = False\n",
    "LENGTH_OF_EACH = 222\n",
    "NUM_TRIALS = 100\n",
    "\n",
    "context = '\\nDataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence.'\n",
    "\n",
    "##############################################################################################################"
   ]
  },
  {
   "cell_type": "raw",
   "id": "97d34a8c-b8c6-4444-9c7c-61458b1663f9",
   "metadata": {},
   "source": [
    "请在model.py中修改模型路径\n",
    "例如我的路径是/data1/ckw/RWKV-3-Pile-430M-20220817-10602\n",
    "如果想使用cuda加速，请参考：https://github.com/BlinkDL/RWKV-v2-RNN-Pile"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ebd826cd-633e-4c3d-aa15-daee44c438ae",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "loading RWKV-RNN /data1/ckw/RWKV-3-Pile-20220720-10704\n",
      "\n",
      "--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. Use GPT to build the hidden state for better speed. <--\n",
      "\n"
     ]
    }
   ],
   "source": [
    "##############################################################################################################\n",
    "\n",
    "model = RWKV_RNN()\n",
    "print('\\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. Use GPT to build the hidden state for better speed. <--\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "60170c2d-be0b-43ef-9ed2-7cc5fbd1bab1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample_logits(out, temperature=1.0, top_p=None):\n",
    "    probs = F.softmax(torch.tensor(out), dim=-1)\n",
    "    sorted_probs, _ = torch.sort(probs, descending=True)\n",
    "\n",
    "    cumulative_probs = torch.cumsum(sorted_probs, dim=-1).numpy()\n",
    "    cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])\n",
    "    probs[probs < cutoff] = 0\n",
    "\n",
    "    if temperature != 1.0:\n",
    "        probs = probs.pow(1.0 / temperature)\n",
    "\n",
    "    return torch.multinomial(probs, num_samples=1)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a7465358-f05a-420a-b3af-69304d8de9f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_TRIALS=3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "fdeb0be0-da7d-4d2b-abf1-54b5909bbb99",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. The group has its own website which is the second largest data whalechina.com and it has been downloaded over 11 million times. It has a huge reach on social media and has over 2.5 million members. It was founded in September 2010 by co-founder and CEO Wang Tianjian.\n",
      "\n",
      "Anytime, when the topic is changing, people must change their minds and make a decision. It’s no wonder that the internet has started to get heated in the past years. There is no other way to know that this information will not get leaked. It is an open secret that every member of the AI and Deep Learning community will benefit from this new social media platform.\n",
      "\n",
      "Although the data whalechina.com is not meant to be a home for AI researchers, it is still an important tool in the machine learning community and an important part of the Internet’s knowledge and knowledge. It is also a platform that has been used by a wide variety of people to learn about the topic.\n",
      "\n",
      "One of the most significant pieces of AI research was done by\n",
      "----------------------------------------------------------------------\n",
      "DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. Datawhalechina helps you learn the concepts of artificial intelligence and analyze it. It offers a computer program to analyze a series of lectures and can analyze any situation to make a decision.\n",
      "\n",
      "Hongwei Online School offers computer science and engineering programs in a complex and energetic environment. It provides online courses, design software, online tutorials, and software development.\n",
      "\n",
      "Goethe University offers online classes in artificial intelligence, machine learning, and the latest trends in technology.\n",
      "\n",
      "About the project: Goethe University is the German think tank for technology, technology, technology, and society. We publish several books, several books, a web page, and a talk show. We make educational and scientific innovations in a collaborative fashion, but there is no way to control the movement of knowledge.\n",
      "\n",
      "We will use your feedback to improve the program, and improve our service.\n",
      "\n",
      "Goethe University is a pioneer in artificial intelligence in Germany. We help organizations achieve their goals through innovation, innovative methods, and positive collaboration.\n",
      "\n",
      "Here at Goethe University, we welcome new students from around the world\n",
      "----------------------------------------------------------------------\n",
      "DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence.\n",
      "\n",
      "And here is how they did it:\n",
      "\n",
      "> This application is free and open-source and allows you to build and distribute the software, the system software and the platform. The software is free and open-source.\n",
      "\n",
      "> It has been developed by the authors and has been published in the USENIX Internet Journal.\n",
      "\n",
      "> The main purpose of this platform is to build software that will enable human learning and data sharing in many different fields, including AI, robotics, energy, the Internet of Things and more.\n",
      "\n",
      "There is also a section about open source software and the project that is open-source. This is the team’s final article.\n",
      "\n",
      "## Future Work\n",
      "\n",
      "This project is for all of you who are learning and teaching, in particular, the data-scientists and researchers that are working in the field. The author, Joao, has written a lot of open source software. In this project, he plans to do the following:\n",
      "\n",
      "- Create a site to collect data on data science and robotics.\n",
      "\n",
      "- Publ\n",
      "----------------------------------------------------------------------"
     ]
    }
   ],
   "source": [
    "for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):\n",
    "    ctx = [model.tokenizer.encode(context)][0]\n",
    "    src_len = len(ctx)\n",
    "    print(context, end='')\n",
    "\n",
    "    model.clear()\n",
    "    if TRIAL == 0: # build the RNN hidden state?\n",
    "        init_state = types.SimpleNamespace()\n",
    "        for i in range(src_len if DEBUG_DEBUG else src_len):\n",
    "            x = ctx[:i+1]\n",
    "            if i == src_len - 1:\n",
    "                init_state.out = model.run(x)\n",
    "            else:\n",
    "                model.run(x)\n",
    "        model.save(init_state)\n",
    "    else:\n",
    "        model.load(init_state)\n",
    "\n",
    "    if DEBUG_DEBUG:\n",
    "        out = init_state.out\n",
    "        print('\\n', np.array(x), '==>', np.array(\n",
    "            out), np.max(out), np.min(out))\n",
    "\n",
    "    for i in range(src_len, src_len + (0 if DEBUG_DEBUG else LENGTH_OF_EACH)):\n",
    "        x = ctx[:i+1]\n",
    "        x = x[-model.ctx_len:]\n",
    "\n",
    "        if i == src_len:\n",
    "            out = copy.deepcopy(init_state.out) # load the RNN hidden state\n",
    "        else:\n",
    "            out = model.run(x) # run the RNN\n",
    "\n",
    "        out[0] = -999999999  # disable <|endoftext|>\n",
    "\n",
    "        char = sample_logits(out, temperature=TEMPERATURE, top_p=TOP_P)\n",
    "        char = char.item()\n",
    "        print(model.tokenizer.decode(char), end='', flush=True)\n",
    "\n",
    "        ctx += [char]\n",
    "    print('\\n' + '-' * 70, end='')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c581c4a7-acfc-41b2-ace4-f6375e7e7ea6",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
